#!/usr/bin/env python3
"""
Phase 3: Observable Regime Classification
Tables 6-9: Governance interactions, predicted transition logit, quintile splits, mediation.
"""

import sys
import pandas as pd
import numpy as np
from pathlib import Path
from scipy import stats as sp_stats

# Root of the project checkout. NOTE(review): hard-coded absolute path (looks
# like a WSL /mnt/c mount) — adjust when running on another machine.
PROJECT_DIR = Path("/mnt/c/demographics_capital_flows")
# Put the shared `multilateral/src` directory on the import path; this must run
# before the `from model import PanelGLS` line below.
sys.path.insert(0, str(PROJECT_DIR / "multilateral" / "src"))
from model import PanelGLS

# Input panel lives in data/processed; Tables 6–9 are written to output/tables,
# which is created here (including parents) if it does not already exist.
CCA_DIR = PROJECT_DIR / "cca_tipping"
PROCESSED_DIR = CCA_DIR / "data" / "processed"
TABLE_DIR = CCA_DIR / "output" / "tables"
TABLE_DIR.mkdir(parents=True, exist_ok=True)

# ── Helpers ────────────────────────────────────────────────────────────────
def star(p):
    """Return significance stars for p-value ``p``.

    ``"***"`` below 0.01, ``"**"`` below 0.05, ``"*"`` below 0.10, otherwise
    ``""``. A NaN p-value also yields ``""`` because every comparison with NaN
    is False.
    """
    for cutoff, marks in ((0.01, "***"), (0.05, "**"), (0.10, "*")):
        if p < cutoff:
            return marks
    return ""

def fmt(c, p):
    """Format coefficient ``c`` to two decimals, suffixed with the stars for ``p``."""
    return "{:.2f}".format(c) + star(p)

def run_gls(df, dep_var, indep_vars, min_obs=50):
    """Fit a PanelGLS regression of ``dep_var`` on ``indep_vars``.

    Rows missing the dependent variable or any regressor are dropped first
    (listwise deletion). The remaining rows' ``iso3`` and ``year`` columns are
    passed to ``PanelGLS.fit`` alongside y and X.

    Parameters
    ----------
    df : pandas.DataFrame
        Panel containing ``dep_var``, every name in ``indep_vars``, ``iso3``
        and ``year``.
    dep_var : str
        Dependent-variable column.
    indep_vars : iterable of str
        Regressor columns, in the order their estimates are reported. Lists,
        tuples and other iterables are all accepted.
    min_obs : int, default 50
        Minimum number of complete rows required before a fit is attempted.

    Returns
    -------
    dict or None
        ``None`` when fewer than ``min_obs`` complete rows remain. Otherwise a
        dict with ``r_squared``, ``n_obs`` and ``n_countries``, plus
        ``coefficients`` / ``std_errors`` / ``p_values`` dicts keyed by
        regressor name.
    """
    # Normalise to a list so `[dep_var] + indep_vars` also works for tuples etc.
    indep_vars = list(indep_vars)
    comp = df.dropna(subset=[dep_var] + indep_vars)
    if len(comp) < min_obs:
        return None
    gls = PanelGLS()
    gls.fit(comp[dep_var].values, comp[indep_vars].values,
            comp['iso3'].values, comp['year'].values)
    # NOTE(review): assumes gls.beta / se / pvalues are ordered like the columns
    # of X (no intercept entry prepended) — confirm against PanelGLS.
    return {
        'r_squared': gls.r_squared,
        'n_obs': gls.n_obs,
        'n_countries': gls.n_countries,
        'coefficients': dict(zip(indep_vars, gls.beta)),
        'std_errors': dict(zip(indep_vars, gls.se)),
        'p_values': dict(zip(indep_vars, gls.pvalues)),
    }

# ── Load Data ──────────────────────────────────────────────────────────────
print("=" * 70)
print("PHASE 3: OBSERVABLE REGIME CLASSIFICATION")
print("=" * 70)

# Estimation sample: every panel row with an observed CA/GDP ratio whose year
# falls in 1986–2024 (both ends inclusive).
panel = pd.read_csv(PROCESSED_DIR / "cca_panel.csv")
est = panel.loc[panel['ca_gdp'].notna() & panel['year'].between(1986, 2024)].copy()

# Demographic factors followed by the baseline macro controls; together they
# form the regressor set of every baseline specification below.
demo_vars = ['Z_1', 'Z_2', 'Z_3']
baseline_controls = [
    'fiscal_bal_gdp', 'kaopen', 'expected_growth',
    'nfa_gdp_lag', 'log_rel_opw', 'health_exp_gdp',
]
base_vars = [*demo_vars, *baseline_controls]

# ═══════════════════════════════════════════════════════════════════════════
# TABLE 6: GOVERNANCE / KAOPEN / FINANCIAL DEPTH INTERACTIONS
# ═══════════════════════════════════════════════════════════════════════════
print("\n" + "=" * 70)
print("TABLE 6: INTERACTION TESTS")
print("=" * 70)

# Demean the moderators so each Z main effect is its marginal effect at the
# moderator's mean. The means are taken over the rows that can actually enter
# the regressions (complete baseline regressors, plus governance for gov_dm)
# rather than over all of `est`, so "mean" is the estimation-sample mean. For
# the governance specs this is the same sample and centring that Table 9 uses,
# so the Z × governance results agree across the two tables.
base_ok = est[base_vars].notna().all(axis=1)
gov_ok = base_ok & est['governance_composite'].notna()
gov_mean = est.loc[gov_ok, 'governance_composite'].mean()
est['gov_dm'] = est['governance_composite'] - gov_mean
est['kaopen_dm'] = est['kaopen'] - est.loc[base_ok, 'kaopen'].mean()

# Create interactions
for z in demo_vars:
    est[f'{z}_x_gov'] = est[z] * est['gov_dm']
    est[f'{z}_x_kaopen_dm'] = est[z] * est['kaopen_dm']

gov_int_vars = [f'{z}_x_gov' for z in demo_vars]
kaopen_int_vars = [f'{z}_x_kaopen_dm' for z in demo_vars]

# Raw `kaopen` is already a baseline control, and kaopen_dm differs from it only
# by a constant. The KAOPEN specs therefore add just the Z × kaopen_dm
# interactions. Entering kaopen_dm as a main effect as well would duplicate
# kaopen: it is perfectly collinear with it if the model has an intercept or
# country effects. governance_composite is NOT in the baseline, so gov_dm does
# enter as its own main effect.
interaction_specs = [
    ("Baseline", base_vars),
    ("+ Z × governance", base_vars + ['gov_dm'] + gov_int_vars),
    ("+ Z × KAOPEN (dm)", base_vars + kaopen_int_vars),
    ("+ Z × gov + Z × KAOPEN",
     base_vars + ['gov_dm'] + gov_int_vars + kaopen_int_vars),
    # CCA dummy for comparison
    ("+ CCA dummy", base_vars + ['is_cca']),
    # Does governance subsume CCA?
    ("+ Z × gov + CCA dummy",
     base_vars + ['gov_dm', 'is_cca'] + gov_int_vars),
]

int_rows = []
for spec_name, spec_vars in interaction_specs:
    r = run_gls(est, 'ca_gdp', spec_vars)
    if r is None:
        print(f"  {spec_name}: FAILED (insufficient obs)")
        continue
    row = {'specification': spec_name, 'n_countries': r['n_countries'],
           'n_obs': r['n_obs'], 'r_squared': r['r_squared']}
    for v in spec_vars:
        row[f'{v}_coef'] = r['coefficients'].get(v, np.nan)
        row[f'{v}_se'] = r['std_errors'].get(v, np.nan)
        row[f'{v}_pval'] = r['p_values'].get(v, np.nan)
    int_rows.append(row)

    z1c = r['coefficients']['Z_1']
    z1p = r['p_values']['Z_1']
    extra_info = ""
    if 'Z_1_x_gov' in r['coefficients']:
        extra_info += f", Z₁×gov={r['coefficients']['Z_1_x_gov']:.2f}{star(r['p_values']['Z_1_x_gov'])}"
    if 'is_cca' in r['coefficients']:
        extra_info += f", CCA={r['coefficients']['is_cca']:.2f}{star(r['p_values']['is_cca'])}"
    print(f"  {spec_name:<35s}: Z₁={z1c:7.2f}{star(z1p)} (p={z1p:.4f}), R²={r['r_squared']:.4f}{extra_info}")

int_df = pd.DataFrame(int_rows)
int_df.to_csv(TABLE_DIR / "table6_interactions.csv", index=False)

# Markdown: one column per successful spec; "—" where a variable is not in it.
md = ["# Table 6: Governance and KAOPEN Interactions\n"]
key_vars = demo_vars + ['gov_dm'] + gov_int_vars + kaopen_int_vars + ['is_cca']
col_labels = [r['specification'] for r in int_rows]
md.append("| Variable | " + " | ".join(col_labels) + " |")
md.append("|---|" + "---|" * len(col_labels))
for v in key_vars:
    cells = []
    for r in int_rows:
        c = r.get(f'{v}_coef', np.nan)
        p = r.get(f'{v}_pval', np.nan)
        se = r.get(f'{v}_se', np.nan)
        if pd.notna(c):
            cells.append(f"{c:.2f}{star(p)} ({se:.2f})")
        else:
            cells.append("—")
    md.append(f"| {v} | " + " | ".join(cells) + " |")
# Footer
md.append("|---|" + "---|" * len(col_labels))
for stat in ['n_countries', 'n_obs', 'r_squared']:
    cells = [f"{r[stat]:.4f}" if stat == 'r_squared' else str(r[stat]) for r in int_rows]
    md.append(f"| {stat} | " + " | ".join(cells) + " |")
(TABLE_DIR / "table6_interactions.md").write_text("\n".join(md))

# ═══════════════════════════════════════════════════════════════════════════
# TABLE 7: PREDICTED TRANSITION LOGIT
# ═══════════════════════════════════════════════════════════════════════════
print("\n" + "=" * 70)
print("TABLE 7: PREDICTED TRANSITION PROBABILITY")
print("=" * 70)

from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler

# Predict is_transition from observables: governance, KAOPEN and log GDP per
# capita (PPP, floored at 100 before taking logs). The sample also requires the
# CA regressors so the second-stage GLS below runs on the same rows.
logit_vars = ['governance_composite', 'kaopen', 'gdp_pc_ppp']
logit_sample = est.dropna(subset=logit_vars + ['is_transition', 'ca_gdp'] + base_vars).copy()
logit_sample['log_gdp_pc'] = np.log(logit_sample['gdp_pc_ppp'].clip(lower=100))

X_logit = logit_sample[['governance_composite', 'kaopen', 'log_gdp_pc']].values
y_logit = logit_sample['is_transition'].values

# Standardise the features so the logit coefficients are per-SD effects.
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X_logit)

# NOTE(review): sklearn's LogisticRegression applies an L2 penalty by default
# (C=1.0), so the reported coefficients are mildly shrunk, not plain MLE logit.
lr = LogisticRegression(max_iter=1000)
lr.fit(X_scaled, y_logit)

# Predicted probability of the second class in lr.classes_ (class 1 when
# is_transition is coded 0/1). Both the probabilities and the accuracy are
# in-sample: they are computed on the rows the logit was fitted to.
logit_sample['pred_transition'] = lr.predict_proba(X_scaled)[:, 1]
accuracy = (lr.predict(X_scaled) == y_logit).mean()
print(f"  Logit accuracy: {accuracy:.3f}")
print(f"  Coefficients: gov={lr.coef_[0][0]:.3f}, kaopen={lr.coef_[0][1]:.3f}, log_gdp={lr.coef_[0][2]:.3f}")

# Now use pred_transition (demeaned on this sample) as a continuous regime
# indicator. NOTE: it is a generated regressor — the GLS standard errors below
# do not account for first-stage estimation error.
logit_sample['pred_trans_dm'] = logit_sample['pred_transition'] - logit_sample['pred_transition'].mean()
for z in demo_vars:
    logit_sample[f'{z}_x_pred_trans'] = logit_sample[z] * logit_sample['pred_trans_dm']

pred_int_vars = [f'{z}_x_pred_trans' for z in demo_vars]
pred_spec = base_vars + ['pred_trans_dm'] + pred_int_vars

r_pred = run_gls(logit_sample, 'ca_gdp', pred_spec)
if r_pred:
    print(f"\n  With predicted transition interaction:")
    print(f"  Z₁={r_pred['coefficients']['Z_1']:.2f}{star(r_pred['p_values']['Z_1'])}, "
          f"Z₁×pred_trans={r_pred['coefficients']['Z_1_x_pred_trans']:.2f}"
          f"{star(r_pred['p_values']['Z_1_x_pred_trans'])}")

# Also run with CCA dummy to compare
pred_spec_cca = pred_spec + ['is_cca']
r_pred_cca = run_gls(logit_sample, 'ca_gdp', pred_spec_cca)
if r_pred_cca:
    print(f"  + CCA dummy: Z₁={r_pred_cca['coefficients']['Z_1']:.2f}{star(r_pred_cca['p_values']['Z_1'])}, "
          f"CCA={r_pred_cca['coefficients']['is_cca']:.2f}{star(r_pred_cca['p_values']['is_cca'])}")

# Save logit results (coefficients are on the standardised features)
logit_results = {
    'accuracy': accuracy,
    'coef_governance': lr.coef_[0][0],
    'coef_kaopen': lr.coef_[0][1],
    'coef_log_gdp': lr.coef_[0][2],
    'intercept': lr.intercept_[0],
}
pd.DataFrame([logit_results]).to_csv(TABLE_DIR / "table7_logit_model.csv", index=False)

# Variables reported in both the CSV and the markdown table.
t7_vars = ['Z_1', 'Z_2', 'Z_3', 'pred_trans_dm'] + pred_int_vars + ['is_cca']

# Save interaction results. Standard errors are included, as in the Table 6/8
# CSVs and the markdown below.
t7_rows = []
for name, r in [("+ pred_transition interactions", r_pred),
                ("+ pred_transition + CCA", r_pred_cca)]:
    if r is None:
        continue
    row = {'specification': name, 'n_countries': r['n_countries'],
           'n_obs': r['n_obs'], 'r_squared': r['r_squared']}
    for v in t7_vars:
        row[f'{v}_coef'] = r['coefficients'].get(v, np.nan)
        row[f'{v}_se'] = r['std_errors'].get(v, np.nan)
        row[f'{v}_pval'] = r['p_values'].get(v, np.nan)
    t7_rows.append(row)
pd.DataFrame(t7_rows).to_csv(TABLE_DIR / "table7_predicted_transition.csv", index=False)

md = ["# Table 7: Predicted Transition Probability as Regime Indicator\n"]
md.append("Logit model: is_transition = f(governance, KAOPEN, log_GDP_pc)")
md.append(f"Accuracy: {accuracy:.3f}\n")
md.append("| Variable | + Pred. transition | + Pred. trans. + CCA |")
md.append("|---|---|---|")
for v in t7_vars:
    cells = []
    for r in [r_pred, r_pred_cca]:
        if r and v in r['coefficients']:
            c, p = r['coefficients'][v], r['p_values'][v]
            cells.append(f"{c:.2f}{star(p)} ({r['std_errors'][v]:.2f})")
        else:
            cells.append("—")
    md.append(f"| {v} | " + " | ".join(cells) + " |")
(TABLE_DIR / "table7_predicted_transition.md").write_text("\n".join(md))

# ═══════════════════════════════════════════════════════════════════════════
# TABLE 8: QUINTILE SPLITS BY GOVERNANCE
# ═══════════════════════════════════════════════════════════════════════════
print("\n" + "=" * 70)
print("TABLE 8: QUINTILE SPLITS BY GOVERNANCE COMPOSITE")
print("=" * 70)

# Complete-case sample, binned into equal-count governance quintiles numbered
# 1 (lowest governance) to 5. The baseline is then re-estimated within each bin.
gov_sample = est.dropna(subset=base_vars + ['ca_gdp', 'governance_composite']).copy()
gov_sample['gov_quintile'] = 1 + pd.qcut(gov_sample['governance_composite'], 5, labels=False)

quint_rows = []
for q in sorted(gov_sample['gov_quintile'].unique()):
    sub = gov_sample.loc[gov_sample['gov_quintile'] == q]
    gov = sub['governance_composite']
    gov_range = f"[{gov.min():.2f}, {gov.max():.2f}]"
    r = run_gls(sub, 'ca_gdp', base_vars)
    if r is None:
        print(f"  Q{q}: insufficient obs")
        continue
    row = dict(
        quintile=q,
        gov_range=gov_range,
        gov_mean=gov.mean(),
        n_countries=r['n_countries'],
        n_obs=r['n_obs'],
        r_squared=r['r_squared'],
        pct_transition=sub['is_transition'].mean() * 100,
        pct_cca=sub['is_cca'].mean() * 100,
    )
    # Per demographic factor: coefficient, standard error, p-value (in that order).
    for v in demo_vars:
        for suffix, field in (('coef', 'coefficients'), ('se', 'std_errors'), ('pval', 'p_values')):
            row[f'{v}_{suffix}'] = r[field][v]
    quint_rows.append(row)
    z1, z1_p = row['Z_1_coef'], row['Z_1_pval']
    print(f"  Q{q} {gov_range}: Z₁={z1:7.2f}{star(z1_p)} "
          f"(p={z1_p:.4f}), N_c={r['n_countries']}, "
          f"CCA%={row['pct_cca']:.1f}%, Trans%={row['pct_transition']:.1f}%")

quint_df = pd.DataFrame(quint_rows)
quint_df.to_csv(TABLE_DIR / "table8_governance_quintiles.csv", index=False)

md = [
    "# Table 8: Baseline by Governance Quintile\n",
    "| Quintile | Gov range | Mean gov | N_c | Z₁ | SE | p-val | R² | %CCA | %Transition |",
    "|---|---|---|---|---|---|---|---|---|---|",
]
for _, rec in quint_df.iterrows():
    md.append(f"| Q{rec['quintile']:.0f} | {rec['gov_range']} | {rec['gov_mean']:.2f} | "
              f"{rec['n_countries']} | {rec['Z_1_coef']:.2f}{star(rec['Z_1_pval'])} | "
              f"{rec['Z_1_se']:.2f} | {rec['Z_1_pval']:.4f} | {rec['r_squared']:.4f} | "
              f"{rec['pct_cca']:.1f}% | {rec['pct_transition']:.1f}% |")
(TABLE_DIR / "table8_governance_quintiles.md").write_text("\n".join(md))

# ═══════════════════════════════════════════════════════════════════════════
# TABLE 9: MEDIATION TEST — DOES GOVERNANCE ATTENUATE Z₁?
# ═══════════════════════════════════════════════════════════════════════════
print("\n" + "=" * 70)
print("TABLE 9: MEDIATION / ATTENUATION TEST")
print("=" * 70)


def _attenuation_pct(z1_ref, z1_new):
    """Return the % fall in Z₁ from ``z1_ref`` to ``z1_new``; NaN if undefined."""
    if pd.isna(z1_ref) or pd.isna(z1_new) or z1_ref == 0:
        return np.nan
    return (z1_ref - z1_new) / z1_ref * 100


def _med_row(label, res):
    """Summarise one Table 9 regression result as a CSV/markdown row dict."""
    return {"specification": label, "Z_1_coef": res['coefficients']['Z_1'],
            "Z_1_pval": res['p_values']['Z_1'], "r_squared": res['r_squared'],
            "n_obs": res['n_obs'], "n_countries": res['n_countries']}


# Every specification below is estimated on this one complete-case sample
# (baseline regressors + governance), so the Z₁ coefficients are comparable.
med_sample = est.dropna(subset=base_vars + ['ca_gdp', 'governance_composite']).copy()

# Step 1: Baseline Z₁ (without governance)
r_base = run_gls(med_sample, 'ca_gdp', base_vars)

# Step 2: + governance as control
gov_control_vars = base_vars + ['governance_composite']
r_gov = run_gls(med_sample, 'ca_gdp', gov_control_vars)

# Step 3: + governance interactions, with governance demeaned on this sample so
# the Z₁ main effect is its effect at mean governance.
est_med = med_sample.copy()
est_med['gov_dm'] = est_med['governance_composite'] - est_med['governance_composite'].mean()
for z in demo_vars:
    est_med[f'{z}_x_gov'] = est_med[z] * est_med['gov_dm']
gov_full_vars = base_vars + ['gov_dm'] + [f'{z}_x_gov' for z in demo_vars]
r_gov_int = run_gls(est_med, 'ca_gdp', gov_full_vars)

# Attenuation = % fall in Z₁ relative to the baseline. Both ratios need the
# baseline, so they stay NaN if that regression failed.
attenuation = np.nan
attenuation_int = np.nan
if r_base is None:
    print("  Baseline: FAILED (insufficient obs)")
else:
    z1_base = r_base['coefficients']['Z_1']
    print(f"  Baseline Z₁:        {z1_base:.2f}{star(r_base['p_values']['Z_1'])}")
    if r_gov:
        z1_gov = r_gov['coefficients']['Z_1']
        attenuation = _attenuation_pct(z1_base, z1_gov)
        print(f"  + Governance ctrl:   {z1_gov:.2f}{star(r_gov['p_values']['Z_1'])}")
        print(f"  Attenuation:         {attenuation:.1f}%")
    if r_gov_int:
        z1_int = r_gov_int['coefficients']['Z_1']
        attenuation_int = _attenuation_pct(z1_base, z1_int)
        print(f"  + Gov interactions:  {z1_int:.2f}{star(r_gov_int['p_values']['Z_1'])}")
        print(f"  Attenuation (int):   {attenuation_int:.1f}%")

# Marginal R² contributions: drop one regressor group at a time from the
# governance-control model. NOTE: this is not a true Shapley decomposition,
# which would average each group's contribution over all inclusion orders.
# R²(full) - R²(without Z)   = Z contribution
# R²(full) - R²(without gov) = gov contribution
base_no_z = [v for v in base_vars if v not in demo_vars]
r_no_z = run_gls(med_sample, 'ca_gdp', base_no_z + ['governance_composite'])
r_no_gov = r_base  # same regressors and sample as the baseline, so reuse rather than refit

if r_gov and r_no_z and r_no_gov:
    r2_full = r_gov['r_squared']
    r2_no_z = r_no_z['r_squared']
    r2_no_gov = r_no_gov['r_squared']
    z_contrib = r2_full - r2_no_z
    gov_contrib = r2_full - r2_no_gov
    print(f"\n  Shapley R² decomposition:")
    print(f"    Full R²:      {r2_full:.4f}")
    print(f"    R²(no Z):     {r2_no_z:.4f}, Z contrib: {z_contrib:.4f}")
    print(f"    R²(no gov):   {r2_no_gov:.4f}, gov contrib: {gov_contrib:.4f}")

# Save mediation results. Specs whose regression failed are skipped instead of
# crashing; the baseline row carries no attenuation value.
med_rows = []
if r_base:
    med_rows.append(_med_row("Baseline", r_base))
if r_gov:
    med_rows.append({**_med_row("+ Governance control", r_gov),
                     "attenuation_pct": attenuation})
if r_gov_int:
    med_rows.append({**_med_row("+ Governance interactions", r_gov_int),
                     "attenuation_pct": attenuation_int})

pd.DataFrame(med_rows).to_csv(TABLE_DIR / "table9_mediation.csv", index=False)

md = ["# Table 9: Mediation / Attenuation Test\n"]
md.append("| Specification | Z₁ | p-val | R² | Attenuation |")
md.append("|---|---|---|---|---|")
for r in med_rows:
    att = f"{r.get('attenuation_pct', 0):.1f}%" if 'attenuation_pct' in r else "—"
    md.append(f"| {r['specification']} | {r['Z_1_coef']:.2f}{star(r['Z_1_pval'])} | "
              f"{r['Z_1_pval']:.4f} | {r['r_squared']:.4f} | {att} |")
(TABLE_DIR / "table9_mediation.md").write_text("\n".join(md))

print("\nAll Phase 3 tables saved to {}".format(TABLE_DIR))
print("Phase 3 complete.")
